import os
import sys
import json
import datetime
import numpy as np
import skimage.draw
import imgaug
import tensorflow as tf
import build_data
from PIL import Image

# TF1-style command-line flags; values are read throughout the module via FLAGS.
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string(
    'semantic_segmentation_folder',
    './dummy_bottle/newdata/SegmentationClassRaw',
    'Folder containing semantic segmentation annotations.')

tf.app.flags.DEFINE_string(
    'dataset_dir',
    './dummy_bottle/newdata/',
    'Folder containing images.')

tf.app.flags.DEFINE_string(
    'output_tf_dir',
    './dummy_bottle/tfrecord',
    'Folder saving tfrecords.'
)
tf.app.flags.DEFINE_string(
    'segmentation_format',
    'png',
    'segmentation format'
)
# NOTE: defaults to the same folder as 'semantic_segmentation_folder' above;
# the ground-truth PNGs generated by BottleDataset are written here.
tf.app.flags.DEFINE_string('output_seg_dir',
                           './dummy_bottle/newdata/SegmentationClassRaw',
                           'folder to save modified ground truth annotations.')

# Bottle class names (Chinese product names). Mask pixel value i corresponds
# to bottle_classes[i-1]; pixel value 0 is background.
# if you add/delete element, rebuild model.
bottle_classes = ["脉动", "昆仑山", "零度可乐", "百事可乐", "瓶装王老吉",
                  "雪碧", "海之言", "哇哈哈", "sp100", "农夫山泉", "mini可口可乐",
                  "康师傅冰红茶", "mini青岛啤酒", "青岛啤酒", "可口可乐"]
# Class name -> 1-based class id (0 is reserved for background).
# NOTE(review): the comprehension variable 'id' shadows the builtin of the same name.
source_dataset_name_id_mapping = {name: id for id, name in enumerate(bottle_classes, start=1)}


class BottleDataset:
    """Dataset of VIA-annotated bottle images.

    Indexes classes and images (Mask R-CNN style ``Dataset`` interface:
    ``add_class`` / ``add_image`` / ``prepare`` / ``load_mask``) and converts
    the VIA polygon annotations into semantic-segmentation PNG masks plus
    TFRecord shards for DeepLab-style training.
    """

    def __init__(self, *args, **kargs):
        self._image_ids = []
        self.image_info = []
        # Background is always the first class.
        self.class_info = [{"source": "", "id": 0, "name": "BG"}]
        self.source_class_ids = {}
        # source class id -> number of polygon instances seen so far
        # (populated by load_mask, printed by train()).
        self.class_count = {}

    def add_class(self, source, class_id, class_name):
        """Register a class unless the (source, class_id) pair already exists."""
        assert "." not in source, "Source name cannot contain a dot"
        # Does the class exist already?
        for info in self.class_info:
            if info['source'] == source and info["id"] == class_id:
                # source.class_id combination already available, skip.
                return
        # Add the class.
        self.class_info.append({
            "source": source,
            "id": class_id,
            "name": class_name,
        })

    def add_image(self, source, image_id, path, **kwargs):
        """Record one image's metadata; returns its internal index."""
        image_info = {
            "id": image_id,
            "source": source,
            "path": path,
        }
        image_info.update(kwargs)
        self.image_info.append(image_info)
        return len(self.image_info) - 1

    def load_bottles(self, dataset_dir, subsets):
        """Convert VIA-annotated subsets into segmentation PNGs and TFRecords.

        Args:
          dataset_dir: Root folder; each subset is a sub-folder containing
            images plus a ``via_region_data.json`` annotation file.
          subsets: A subset name or a list of names. A subset named 'val' is
            written to the validation split; every other subset to 'train'.
        """
        image_reader = build_data.ImageReader('jpeg', channels=3)
        for i, c in enumerate(bottle_classes, start=1):
            self.add_class("OwnBottle", i, c)
        subsets = [subsets] if isinstance(subsets, str) else subsets
        # One shard per subset. BUGFIX: the original hard-coded shard_id=0 and
        # _NUM_SHARDS=1, so every non-'val' subset reopened (and truncated) the
        # same 'train-00000-of-00001.tfrecord' — only the last subset survived.
        _NUM_SHARDS = len(subsets)
        # Each subset is one annotator's batch of labeled images.
        for shard_id, subset in enumerate(subsets):
            subset_dir = os.path.join(dataset_dir, subset)
            annotation_json = json.load(open(os.path.join(subset_dir, 'via_region_data.json'), 'r', encoding='utf-8'))
            annotations = list(annotation_json.values())

            # Drop images that have no annotated regions.
            annotations = [anno for anno in annotations if anno['regions']]
            dataset_name = subset if subset == 'val' else 'train'
            output_filename = os.path.join(
                FLAGS.output_tf_dir,
                '%s-%05d-of-%05d.tfrecord' % (dataset_name, shard_id, _NUM_SHARDS))
            with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                # Add images.
                for index, a in enumerate(annotations):
                    sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                        index + 1, len(annotations), shard_id))
                    sys.stdout.flush()
                    # Get the x, y coordinates of the polygon points outlining
                    # each object instance; stored in shape_attributes. The
                    # type check supports both VIA 1.x (dict) and 2.x (list)
                    # region formats.
                    if type(a['regions']) is dict:
                        polygons = [r['shape_attributes'] for r in a['regions'].values()]
                        polygons_class_names = [r["region_attributes"]["class"] for r in a['regions'].values()]
                    else:
                        polygons = [r['shape_attributes'] for r in a['regions']]
                        polygons_class_names = [r["region_attributes"]["class"] for r in a['regions']]

                    # VIA does not store the image size in the JSON, so read the
                    # image itself; manageable since the dataset is tiny.
                    image_path = os.path.join(subset_dir, a['filename'])
                    image_data = tf.gfile.FastGFile(image_path, 'rb').read()
                    height, width = image_reader.read_image_dims(image_data)
                    polygons_source_class_id = [source_dataset_name_id_mapping[name] for name in polygons_class_names]

                    image_index = self.add_image(
                        "OwnBottle",
                        image_id=a['filename'],  # use file name as a unique image id
                        path=image_path,
                        width=width, height=height,
                        polygons=polygons, polygons_source_class_id=polygons_source_class_id)

                    # Rasterize polygons into a class-id mask, save it as the
                    # ground-truth PNG, then read it back for the record.
                    # (The original computed the seg path twice; hoisted here.)
                    image_mask = self.load_mask(image_index)
                    filename = os.path.splitext(os.path.basename(image_path))[0] + "_" + str(index)
                    seg_filename = os.path.join(
                        FLAGS.output_seg_dir,
                        filename + '.' + FLAGS.segmentation_format)
                    self._save_annotation(image_mask, seg_filename)
                    seg_data = tf.gfile.FastGFile(seg_filename, 'rb').read()

                    # Serialize one (image, segmentation) pair as a tf.Example.
                    example = build_data.image_seg_to_tfexample(
                        image_data, a['filename'], height, width, seg_data=seg_data
                    )
                    tfrecord_writer.write(example.SerializeToString())

            sys.stdout.write('\n')
            sys.stdout.flush()

    def _save_annotation(self, annotation, filename):
        """Saves the annotation as png file.

        Args:
          annotation: 2-D segmentation annotation array of per-pixel class ids.
          filename: Output filename.
        """
        pil_image = Image.fromarray(annotation.astype(dtype=np.uint8))
        # NOTE(review): mode 'w' (not 'wb') for binary PNG data mirrors the
        # upstream DeepLab helper — confirm tf.gfile passes bytes through
        # unmodified on the target platform.
        with tf.gfile.Open(filename, mode='w') as f:
            pil_image.save(f, 'PNG')

    def prepare(self, class_map=None):
        """Prepares the Dataset class for use after classes/images are added.

        BUGFIX: the original file defined this method twice with identical
        bodies; the dead duplicate definition has been removed.

        TODO: class map is not supported yet. When done, it should handle
              mapping classes from different datasets to the same class ID.
        """

        def clean_name(name):
            """Returns a shorter version of object names for cleaner display."""
            return ",".join(name.split(",")[:1])

        # Build (or rebuild) everything else from the info dicts.
        self.num_classes = len(self.class_info)
        self.class_ids = np.arange(self.num_classes)
        self.class_names = [clean_name(c["name"]) for c in self.class_info]
        self.num_images = len(self.image_info)
        self._image_ids = np.arange(self.num_images)

        # Mapping from "source.id" keys to internal IDs.
        self.class_from_source_map = {"{}.{}".format(info['source'], info['id']): id
                                      for info, id in zip(self.class_info, self.class_ids)}
        self.image_from_source_map = {"{}.{}".format(info['source'], info['id']): id
                                      for info, id in zip(self.image_info, self.image_ids)}

        # Map sources to the internal class_ids they support.
        self.sources = list(set([i['source'] for i in self.class_info]))
        self.source_class_ids = {}
        # Loop over datasets.
        for source in self.sources:
            self.source_class_ids[source] = []
            # Find classes that belong to this dataset.
            for i, info in enumerate(self.class_info):
                # Include BG class in all datasets.
                if i == 0 or source == info['source']:
                    self.source_class_ids[source].append(i)

    def map_source_class_id(self, source_class_id):
        """Takes a source class ID and returns the int class ID assigned to it.

        For example:
        dataset.map_source_class_id("coco.12") -> 23
        """
        return self.class_from_source_map[source_class_id]

    def get_source_class_id(self, class_id, source):
        """Map an internal class ID to the corresponding class ID in the source dataset."""
        info = self.class_info[class_id]
        assert info['source'] == source
        return info['id']

    @property
    def image_ids(self):
        # Internal 0..num_images-1 ids (a numpy arange after prepare()).
        return self._image_ids

    def source_image_link(self, image_id):
        """Returns the path or URL to the image.
        Override this to return a URL to the image if it's available online for easy
        debugging.
        """
        return self.image_info[image_id]["path"]

    def load_mask(self, image_id):
        """Rasterize an image's polygons into one semantic-label mask.

        Returns:
            masks: uint8 array of shape [height, width]; each pixel holds the
            source class id of the last-drawn polygon covering it, 0 elsewhere.
        """
        info = self.image_info[image_id]
        masks = np.zeros([info["height"], info["width"]],
                         dtype=np.uint8)
        for i, p in enumerate(info["polygons"]):
            rr, cc = skimage.draw.polygon(p['all_points_y'], p['all_points_x'])
            class_id = info['polygons_source_class_id'][i]
            masks[rr, cc] = class_id
            # Tally per-class instance counts for the summary printout.
            self.class_count[class_id] = self.class_count.get(class_id, 0) + 1
        return masks.astype(np.uint8)


def train():
    """Convert the training and validation subsets to TFRecords.

    After conversion, prints one line per class with the number of polygon
    instances seen in the training data (tallied by BottleDataset.load_mask).
    """
    # Training dataset.
    # BUGFIX: the original subset list contained "b1-C" twice, converting that
    # subset twice and double-counting its instances in class_count; the
    # duplicate entry has been dropped (final output files are unchanged).
    dataset_train = BottleDataset()
    dataset_dir = FLAGS.dataset_dir
    dataset_train.load_bottles(dataset_dir,
                               ["b1-C", "b1-B", "1何盛江", "2-吴越", "3刘文豪",
                                "4曾文彬", "5曹海琪", "6李金龙", "7 邹卓", "8zq"])

    # Map class ids back to human-readable names for the summary printout.
    id_to_name = {class_id: name
                  for class_id, name in enumerate(bottle_classes, start=1)}
    for class_id, count in dataset_train.class_count.items():
        print(id_to_name[class_id] + ':' + str(count))
    # Validation dataset.
    dataset_val = BottleDataset()
    dataset_val.load_bottles(dataset_dir, "val")


# Script entry point: run the conversion directly.
# NOTE(review): tf.app.run() is the conventional TF1 entry point (it parses
# flags explicitly); direct invocation relies on FLAGS parsing lazily — confirm.
if __name__ == '__main__':
    train()
